'''
Author: SlytherinGe
LastEditTime: 2021-04-10 20:20:36
'''
import mmcv
import numpy as np
from mmcv.ops.nms import nms
import torch

def _nms_fuse(det_arrays, iou_threshold):
    """Concatenate detection arrays for one (image, class) and run NMS once.

    Args:
        det_arrays: sequence of 2-D arrays, one per result file. Columns 0-3
            are used as boxes and column 4 as scores (presumably
            ``[x1, y1, x2, y2, score]`` as in mmdetection bbox results --
            confirm against the producing model).
        iou_threshold: IoU threshold forwarded to ``mmcv.ops.nms``.

    Returns:
        The concatenated array unchanged when it is empty, otherwise the
        ``dets`` array produced by NMS as a numpy array. NMS output keeps only
        boxes + score, so any columns past index 4 are dropped.
    """
    merged = np.concatenate(det_arrays, axis=0)
    if merged.shape[0] == 0:
        # Nothing to suppress; keep the empty array (and its dtype) as-is.
        return merged
    # Explicit float32 mirrors the legacy torch.Tensor(...) conversion used
    # before; .contiguous() materialises the column slices.
    boxes = torch.as_tensor(merged[:, :4], dtype=torch.float32).contiguous()
    scores = torch.as_tensor(merged[:, 4], dtype=torch.float32).contiguous()
    dets, _ = nms(boxes, scores, iou_threshold)
    return dets.cpu().numpy()


def fus_results(input_files, out_file, threshold=0.8):
    """Fuse several detection result files into one using per-class NMS.

    Each file is loaded with ``mmcv.load`` and is expected to be indexable as
    ``result[image_idx][class_idx] -> 2-D array`` (presumably the standard
    mmdetection bbox result layout -- verify against the producer). For every
    (image, class) pair, the detections from all files are concatenated and a
    single NMS pass is applied. Running NMS once over the union, rather than
    after each file, keeps the result independent of file order for the
    suppression step and avoids redundant work.

    Args:
        input_files: paths of the result files to fuse; at least two.
        out_file: path the fused result is written to via ``mmcv.dump``.
        threshold: IoU threshold passed to ``mmcv.ops.nms``.

    Returns:
        The fused results as a list (per image) of lists (per class) of arrays.

    Raises:
        AssertionError: if fewer than two input files are given.
        ValueError: if the files disagree on the number of images, or on the
            number of classes for some image.
    """
    assert len(input_files) > 1, 'please input more than one result file'

    result_data = [mmcv.load(f) for f in input_files]

    num_imgs = len(result_data[0])
    for path, data in zip(input_files, result_data):
        if len(data) != num_imgs:
            raise ValueError(
                f'{path} has {len(data)} images, expected {num_imgs} '
                f'(from {input_files[0]})')

    out_data = []
    for img_idx in range(num_imgs):
        per_file = [data[img_idx] for data in result_data]
        num_classes = len(per_file[0])
        for path, img_res in zip(input_files, per_file):
            if len(img_res) != num_classes:
                raise ValueError(
                    f'{path} has {len(img_res)} classes for image {img_idx}, '
                    f'expected {num_classes}')
        out_data.append([
            _nms_fuse([img_res[cls_idx] for img_res in per_file], threshold)
            for cls_idx in range(num_classes)
        ])

    mmcv.dump(out_data, out_file)

    return out_data


if __name__ == '__main__':
    import argparse

    # Default batch of jobs: RESULTS[i] lists the result files fused into
    # OUT[i]. These machine-specific paths are used only when no --inputs
    # are given on the command line.
    RESULTS = [[ '/media/gejunyao/Disk/Gejunyao/exp_results/mmdetection_files/vedai_9_cls_retina/vedai_rgb_retinanet/best_results.pkl',
                '/media/gejunyao/Disk/Gejunyao/exp_results/mmdetection_files/vedai_9_cls_retina/vedai_ir_retinanet/best_results.pkl'],
               ]

    OUT = ['/media/gejunyao/Disk/Gejunyao/exp_results/mmdetection_files/vedai_9_cls_retina/vedai_ir_retinanet/fus_dets.pkl',
            ]

    parser = argparse.ArgumentParser(
        description='Fuse detection result files with per-class NMS.')
    parser.add_argument('--inputs', nargs='+',
                        help='result files to fuse (at least two)')
    parser.add_argument('--out', help='path of the fused result file')
    parser.add_argument('--iou-thr', type=float, default=0.5,
                        help='IoU threshold used by NMS (default: 0.5)')
    args = parser.parse_args()

    # --inputs/--out override the default batch with a single fusion job.
    if bool(args.inputs) != bool(args.out):
        parser.error('--inputs and --out must be given together')
    if args.inputs:
        RESULTS = [args.inputs]
        OUT = [args.out]

    assert len(RESULTS) == len(OUT)

    for i, (inputs, out) in enumerate(zip(RESULTS, OUT), start=1):
        fus_results(inputs, out, args.iou_thr)
        print(f'{i}/{len(OUT)} finished!')


